/**
 * @file cuda_runtime_template_wrapper.h
 *
 * @brief C++ template wrappers around the CUDA-style runtime API.
 *
 */

#ifndef __INCLUDE_CUDA_RUNTIME_TEMPLATE_WRAPPER_H__
#define __INCLUDE_CUDA_RUNTIME_TEMPLATE_WRAPPER_H__

#include "mc_runtime_api.h"
#include "mc_runtime_types.h"

#include <limits.h>
#include <stdint.h>

#ifdef __cplusplus
#define __dv(x) = x
#else
#define __dv(x)
#endif

/* Kernel launch: typed C++ template wrappers over the untyped runtime entry points. */

/**
 * @brief Typed overload: launches kernel @p func with the given grid/block
 *        geometry by type-erasing the pointer and forwarding to the untyped
 *        wcudaLaunchKernel entry point.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaLaunchKernel(const T *func, dim3 gridDim, dim3 blockDim,
                                                       void **args, size_t sharedMem = 0,
                                                       mcStream_t stream = 0)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaLaunchKernel(entry, gridDim, blockDim, args, sharedMem, stream);
}

/**
 * @brief Typed overload for cooperative kernel launches; erases the kernel
 *        pointer type and forwards to the untyped wcudaLaunchCooperativeKernel.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaLaunchCooperativeKernel(const T *func, dim3 gridDim,
                                                                  dim3 blockDim, void **args,
                                                                  size_t sharedMem  = 0,
                                                                  mcStream_t stream = 0)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaLaunchCooperativeKernel(entry, gridDim, blockDim, args, sharedMem, stream);
}

/**
 * @brief Typed overload: launches @p func with an extended launch
 *        configuration, forwarding to the untyped wcudaLaunchKernelExC.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaLaunchKernelExC(const mcLaunchConfig_t *config,
                                                          const T *func, void **args)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaLaunchKernelExC(config, entry, args);
}

/* Occupancy runtime C++ template declaration */
/**
 * @brief Typed overload: queries the dynamic shared memory available per block
 *        for @p func at the given block count and block size.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaOccupancyAvailableDynamicSMemPerBlock(
    size_t *dynamicSmemSize, T func, int numBlocks, int blockSize)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaOccupancyAvailableDynamicSMemPerBlock(dynamicSmemSize, entry, numBlocks,
                                                      blockSize);
}

template <class T>
inline mcError_t wcudaOccupancyMaxActiveBlocksPerMultiprocessor(int *numBlocks, T f, int blockSize,
                                                                size_t dynSharedMemPerBlk)
{
    return wcudaOccupancyMaxActiveBlocksPerMultiprocessor(numBlocks, (const void *)f, blockSize,
                                                          dynSharedMemPerBlk);
}

/**
 * @brief Typed overload of the flagged per-multiprocessor occupancy query;
 *        forwards @p func as an untyped pointer.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
    int *numBlocks, T func, int blockSize, size_t dynamicSMemSize, unsigned int flags)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(numBlocks, entry, blockSize,
                                                                   dynamicSMemSize, flags);
}

/**
 * @brief Typed overload: suggests a grid/block size pair that maximizes
 *        occupancy for kernel @p f.
 */
template <typename T>
static __inline__ __host__ mcError_t wcudaOccupancyMaxPotentialBlockSize(
    int *gridSize, int *blockSize, T f, size_t dynSharedMemPerBlk = 0, int blockSizeLimit = 0)
{
    const void *entry = reinterpret_cast<const void *>(f);
    return wcudaOccupancyMaxPotentialBlockSize(gridSize, blockSize, entry, dynSharedMemPerBlk,
                                               blockSizeLimit);
}

/**
 * @brief Typed overload of the flagged potential-block-size query; forwards
 *        @p f as an untyped kernel pointer.
 */
template <typename T>
static __inline__ __host__ mcError_t wcudaOccupancyMaxPotentialBlockSizeWithFlags(
    int *gridSize, int *blockSize, T f, size_t dynSharedMemPerBlk = 0, int blockSizeLimit = 0,
    unsigned int flags = 0)
{
    const void *entry = reinterpret_cast<const void *>(f);
    return wcudaOccupancyMaxPotentialBlockSizeWithFlags(gridSize, blockSize, entry,
                                                        dynSharedMemPerBlk, blockSizeLimit, flags);
}

/**
 * @brief Finds the block size that maximizes occupancy for kernel @p f when
 *        the dynamic shared-memory requirement varies with the block size.
 *
 * @param gridSize   [out] Suggested minimum grid size for a full-machine launch.
 * @param blockSize  [out] Block size with the best occupancy found.
 * @param f          Kernel to query.
 * @param blockSizeToDynamicSMemSize Unary functor: block size -> dynamic
 *                   shared memory (bytes) that the launch would request.
 * @param blockSizeLimit Upper bound on candidate block sizes; 0 means
 *                   "device maximum".
 * @param flags      Occupancy flags forwarded to the runtime query.
 * @return mcSuccess on success, otherwise the first runtime error encountered.
 */
template <typename UnaryFunction, class T>
static inline mcError_t wcudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags(
    int *gridSize, int *blockSize, T f, UnaryFunction blockSizeToDynamicSMemSize,
    int blockSizeLimit = 0, unsigned int flags = 0)
{
    mcError_t status;

    /* Device and function properties */
    int device;
    mcFuncAttributes attr;

    /* Limits */
    int maxThreadsPerMultiProcessor;
    int warpSize;
    int devMaxThreadsPerBlock;
    int multiProcessorCount;
    int funcMaxThreadsPerBlock;
    int occupancyLimit;
    int granularity;

    /* Recorded maximum */
    int maxBlockSize = 0;
    int numBlocks    = 0;
    int maxOccupancy = 0;

    /* Temporaries */
    int blockSizeToTryAligned;
    int blockSizeToTry;
    int blockSizeLimitAligned;
    int occupancyInBlocks;
    int occupancyInThreads;
    size_t dynamicSMemSize;

    /* Check user input */
    if (!gridSize || !blockSize || !f) {
        return mcErrorInvalidValue;
    }

    /* Obtain device and function properties */
    status = mcGetDevice(&device);
    if (status != mcSuccess) {
        return status;
    }

    status = mcDeviceGetAttribute(&maxThreadsPerMultiProcessor,
                                  mcDeviceAttributeMaxThreadsPerMultiProcessor, device);
    if (status != mcSuccess) {
        return status;
    }

    status = mcDeviceGetAttribute(&warpSize, mcDeviceAttributeWarpSize, device);
    if (status != mcSuccess) {
        return status;
    }

    status =
        mcDeviceGetAttribute(&devMaxThreadsPerBlock, mcDeviceAttributeMaxThreadsPerBlock, device);
    if (status != mcSuccess) {
        return status;
    }

    status =
        mcDeviceGetAttribute(&multiProcessorCount, mcDeviceAttributeMultiProcessorCount, device);
    if (status != mcSuccess) {
        return status;
    }

    status = mcFuncGetAttributes(&attr, reinterpret_cast<const void *>(f));
    if (status != mcSuccess) {
        return status;
    }

    funcMaxThreadsPerBlock = attr.maxThreadsPerBlock;

    /* Try each block size, and pick the block size with maximum occupancy */
    occupancyLimit = maxThreadsPerMultiProcessor;
    granularity    = warpSize;

    if (blockSizeLimit == 0) {
        blockSizeLimit = devMaxThreadsPerBlock;
    }

    if (devMaxThreadsPerBlock < blockSizeLimit) {
        blockSizeLimit = devMaxThreadsPerBlock;
    }

    if (funcMaxThreadsPerBlock < blockSizeLimit) {
        blockSizeLimit = funcMaxThreadsPerBlock;
    }

    blockSizeLimitAligned = ((blockSizeLimit + (granularity - 1)) / granularity) * granularity;

    for (blockSizeToTryAligned = blockSizeLimitAligned; blockSizeToTryAligned > 0;
         blockSizeToTryAligned -= granularity) {
        /**
         * This is needed for the first iteration, because blockSizeLimitAligned could be greater
         * than blockSizeLimit
         */
        if (blockSizeLimit < blockSizeToTryAligned) {
            blockSizeToTry = blockSizeLimit;
        } else {
            blockSizeToTry = blockSizeToTryAligned;
        }

        /**
         * FIX: evaluate the user's functor for EACH candidate block size.
         * Previously it was evaluated once, outside the loop, with the raw
         * blockSizeLimit argument (which may even be 0 before the default is
         * substituted), so every candidate was scored with the wrong dynamic
         * shared-memory size.
         */
        dynamicSMemSize = blockSizeToDynamicSMemSize(blockSizeToTry);

        status = mcOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
            &occupancyInBlocks, reinterpret_cast<const void *>(f), blockSizeToTry, dynamicSMemSize,
            flags);

        if (status != mcSuccess) {
            return status;
        }

        occupancyInThreads = blockSizeToTry * occupancyInBlocks;

        if (occupancyInThreads > maxOccupancy) {
            maxBlockSize = blockSizeToTry;
            numBlocks    = occupancyInBlocks;
            maxOccupancy = occupancyInThreads;
        }

        /* Early out if we have reached the maximum occupancy for the SM */
        if (occupancyLimit == maxOccupancy) {
            break;
        }
    }

    /* Return best available */

    /* Suggested min grid size to achieve a full machine launch */
    *gridSize  = numBlocks * multiProcessorCount;
    *blockSize = maxBlockSize;

    return status;
}

/**
 * @brief Convenience form of the variable-shared-memory occupancy search with
 *        default occupancy flags.
 */
template <typename UnaryFunction, class T>
static __inline__ __host__ mcError_t wcudaOccupancyMaxPotentialBlockSizeVariableSMem(
    int *gridSize, int *blockSize, T f, UnaryFunction blockSizeToDynamicSMemSize,
    int blockSizeLimit = 0)
{
    mcError_t result = wcudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags(
        gridSize, blockSize, f, blockSizeToDynamicSMemSize, blockSizeLimit, mcOccupancyDefault);
    return result;
}

/*  Memory Management */
/** @brief Typed overload: allocates @p sizeBytes of device memory into *ptr. */
template <class T> static __inline__ __host__ mcError_t wcudaMalloc(T **ptr, size_t sizeBytes)
{
    void **untyped = reinterpret_cast<void **>(ptr);
    return wcudaMalloc(untyped, sizeBytes);
}

/** @brief Typed overload: allocates page-locked (pinned) host memory into *ptr. */
template <class T>
static __inline__ __host__ mcError_t wcudaMallocHost(T **ptr, size_t size, unsigned int flags = 0)
{
    void **untyped = reinterpret_cast<void **>(ptr);
    return wcudaMallocHost(untyped, size, flags);
}

/** @brief Typed overload: frees device memory previously allocated for @p ptr. */
template <class T> static inline mcError_t wcudaFree(T *ptr)
{
    return wcudaFree(reinterpret_cast<void *>(ptr));
}

/** @brief Typed overload: frees pinned host memory previously allocated for @p ptr. */
template <class T> static inline mcError_t wcudaFreeHost(T *ptr)
{
    return wcudaFreeHost(reinterpret_cast<void *>(ptr));
}

/** @brief Typed overload: allocates pinned host memory with the given flags. */
template <class T>
static inline __host__ mcError_t wcudaHostAlloc(T **pHost, size_t size, unsigned int flags)
{
    void **untyped = reinterpret_cast<void **>(pHost);
    return wcudaHostAlloc(untyped, size, flags);
}

/** @brief Typed overload: allocates unified (managed) memory into *devPtr. */
template <class T>
static inline mcError_t wcudaMallocManaged(T **devPtr, size_t size,
                                           unsigned int flags __dv(mcMemAttachGlobal))
{
    void **untyped = reinterpret_cast<void **>(devPtr);
    return wcudaMallocManaged(untyped, size, flags);
}

/** @brief Typed overload: attaches managed memory @p devPtr to @p stream. */
template <class T>
static __inline__ __host__ mcError_t wcudaStreamAttachMemAsync(
    mcStream_t stream, T *devPtr, size_t length = 0, unsigned int flags = mcMemAttachSingle)
{
    void *untyped = reinterpret_cast<void *>(devPtr);
    return wcudaStreamAttachMemAsync(stream, untyped, length, flags);
}

/** @brief Typed overload: resolves the device address of @p symbol into *devPtr. */
template <class T>
static inline __host__ mcError_t wcudaGetSymbolAddress(void **devPtr, const T &symbol)
{
    return wcudaGetSymbolAddress(devPtr, static_cast<const void *>(&symbol));
}

/** @brief Typed overload: queries the size in bytes of device symbol @p symbol. */
template <class T>
static inline __host__ mcError_t wcudaGetSymbolSize(size_t *size, const T &symbol)
{
    return wcudaGetSymbolSize(size, static_cast<const void *>(&symbol));
}

/** @brief Typed overload: copies @p count bytes from device symbol @p symbol to @p dst. */
template <class T>
static inline __host__ mcError_t
wcudaMemcpyFromSymbol(void *dst, const T &symbol, size_t count, size_t offset = 0,
                      enum _mcMemcpyKind kind = mcMemcpyDeviceToHost)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaMemcpyFromSymbol(dst, sym, count, offset, kind);
}

/** @brief Typed overload: asynchronous copy from device symbol @p symbol on @p stream. */
template <class T>
static inline __host__ mcError_t
wcudaMemcpyFromSymbolAsync(void *dst, const T &symbol, size_t count, size_t offset = 0,
                           enum _mcMemcpyKind kind = mcMemcpyDeviceToHost, mcStream_t stream = 0)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaMemcpyFromSymbolAsync(dst, sym, count, offset, kind, stream);
}

/** @brief Typed overload: copies @p count bytes from @p src to device symbol @p symbol. */
template <class T>
static inline __host__ mcError_t wcudaMemcpyToSymbol(const T &symbol, const void *src, size_t count,
                                                     size_t offset           = 0,
                                                     enum _mcMemcpyKind kind = mcMemcpyHostToDevice)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaMemcpyToSymbol(sym, src, count, offset, kind);
}

/** @brief Typed overload: asynchronous copy to device symbol @p symbol on @p stream. */
template <class T>
static inline __host__ mcError_t
wcudaMemcpyToSymbolAsync(const T &symbol, const void *src, size_t count, size_t offset = 0,
                         enum _mcMemcpyKind kind = mcMemcpyHostToDevice, mcStream_t stream = 0)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaMemcpyToSymbolAsync(sym, src, count, offset, kind, stream);
}

/**
 * @brief Typed overload: allocates pitched 2-D device memory of
 *        @p width x @p height bytes; the row pitch is returned in *pitch.
 *        (Fixes the misspelled `heigh` parameter name; C++ callers are
 *        unaffected by parameter renames.)
 */
template <class T>
static __inline__ __host__ mcError_t wcudaMallocPitch(T **devPtr, size_t *pitch, size_t width,
                                                      size_t height)
{
    return wcudaMallocPitch((void **)devPtr, pitch, width, height);
}

/** @brief Typed overload: maps pinned host pointer @p pHost to its device alias. */
template <class T>
static inline __host__ mcError_t wcudaHostGetDevicePointer(T **pDevice, void *pHost,
                                                           unsigned int flags)
{
    void **untyped = reinterpret_cast<void **>(pDevice);
    return wcudaHostGetDevicePointer(untyped, pHost, flags);
}

/**
 * @brief Typed overload: 2-D pitched copy of @p width x @p height bytes.
 *        Fix: @p src is const, so it is forwarded as `const void *` instead
 *        of having its constness stripped by a C-style `(void *)` cast.
 */
template <class T>
static inline __host__ mcError_t wcudaMemcpy2D(T *dst, size_t dpitch, const T *src, size_t spitch,
                                               size_t width, size_t height, mcMemcpyKind kind)
{
    return wcudaMemcpy2D((void *)dst, dpitch, (const void *)src, spitch, width, height, kind);
}

/**
 * @brief Typed overload: asynchronous 2-D pitched copy on @p stream.
 *        Fix: @p src is const, so it is forwarded as `const void *` instead
 *        of having its constness stripped by a C-style `(void *)` cast.
 */
template <class T>
static inline __host__ mcError_t wcudaMemcpy2DAsync(T *dst, size_t dpitch, const T *src,
                                                    size_t spitch, size_t width, size_t height,
                                                    mcMemcpyKind kind, mcStream_t stream = 0)
{
    return wcudaMemcpy2DAsync((void *)dst, dpitch, (const void *)src, spitch, width, height, kind,
                              stream);
}

/** @brief Typed overload: synchronous host-to-device copy of @p ByteCount bytes. */
template <class T>
static inline __host__ mcError_t wcuMemcpyHtoD(mcDeviceptr_t dstDevice, const T *srcHost,
                                               size_t ByteCount)
{
    const void *src = static_cast<const void *>(srcHost);
    return wcuMemcpyHtoD(dstDevice, src, ByteCount);
}

/** @brief Typed overload: asynchronous host-to-device copy on @p hStream. */
template <class T>
static inline __host__ mcError_t wcuMemcpyHtoDAsync(mcDeviceptr_t dstDevice, const T *srcHost,
                                                    size_t ByteCount, mcStream_t hStream)
{
    const void *src = static_cast<const void *>(srcHost);
    return wcuMemcpyHtoDAsync(dstDevice, src, ByteCount, hStream);
}

/** @brief Typed overload: synchronous device-to-host copy of @p ByteCount bytes. */
template <class T>
static inline __host__ mcError_t wcuMemcpyDtoH(T *dstHost, mcDeviceptr_t srcDevice,
                                               size_t ByteCount)
{
    void *dst = reinterpret_cast<void *>(dstHost);
    return wcuMemcpyDtoH(dst, srcDevice, ByteCount);
}

/** @brief Typed overload: asynchronous device-to-host copy on @p hStream. */
template <class T>
static inline __host__ mcError_t wcuMemcpyDtoHAsync(T *dstHost, mcDeviceptr_t srcDevice,
                                                    size_t ByteCount, mcStream_t hStream)
{
    void *dst = reinterpret_cast<void *>(dstHost);
    return wcuMemcpyDtoHAsync(dst, srcDevice, ByteCount, hStream);
}

/* SOMA Stream Ordered Memory Allocator */
/** @brief Typed overload: stream-ordered device allocation on @p hStream. */
template <class T>
static inline mcError_t wcudaMallocAsync(T **devPtr, size_t size, mcStream_t hStream)
{
    void **untyped = reinterpret_cast<void **>(devPtr);
    return wcudaMallocAsync(untyped, size, hStream);
}
/** @brief Typed overload: stream-ordered allocation from the given memory pool. */
template <class T>
static inline mcError_t wcudaMallocAsync(T **devPtr, size_t size, mcMemPool_t memPool,
                                         mcStream_t stream)
{
    void **untyped = reinterpret_cast<void **>(devPtr);
    return wcudaMallocAsync(untyped, size, memPool, stream);
}
/** @brief Typed overload: stream-ordered allocation from @p memPool on @p stream. */
template <class T>
static inline mcError_t wcudaMallocFromPoolAsync(T **devPtr, size_t size, mcMemPool_t memPool,
                                                 mcStream_t stream)
{
    void **untyped = reinterpret_cast<void **>(devPtr);
    return wcudaMallocFromPoolAsync(untyped, size, memPool, stream);
}

/* Execution Control runtime C++ template declaration */
/** @brief Typed overload: sets the preferred cache configuration for kernel @p func. */
template <class T>
static __inline__ __host__ mcError_t wcudaFuncSetCacheConfig(T *func,
                                                             enum mcFuncCache_t cacheConfig)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaFuncSetCacheConfig(entry, cacheConfig);
}

/** @brief Typed overload: sets the shared-memory bank configuration for kernel @p func. */
template <class T>
static __inline__ __host__ mcError_t wcudaFuncSetSharedMemConfig(T *func, mcSharedMemConfig config)
{
    const void *entry = reinterpret_cast<const void *>(func);
    return wcudaFuncSetSharedMemConfig(entry, config);
}

/** @brief Typed overload: queries the function attributes of kernel @p entry. */
template <class T>
static __inline__ __host__ mcError_t wcudaFuncGetAttributes(mcFuncAttributes *attr, T *entry)
{
    const void *fn = reinterpret_cast<const void *>(entry);
    return wcudaFuncGetAttributes(attr, fn);
}

/** @brief Typed overload: sets attribute @p attr of kernel @p entry to @p value. */
template <class T>
static __inline__ __host__ mcError_t wcudaFuncSetAttribute(T *entry, mcFuncAttribute attr,
                                                           int value)
{
    const void *fn = reinterpret_cast<const void *>(entry);
    return wcudaFuncSetAttribute(fn, attr, value);
}

/* graph function */
/** @brief Typed overload: adds a graph node copying @p count bytes into device symbol @p symbol. */
template <class T>
static __inline__ __host__ mcError_t wcudaGraphAddMemcpyNodeToSymbol(
    mcGraphNode_t *pGraphNode, mcGraph_t graph, const mcGraphNode_t *pDependencies,
    size_t numDependencies, const T &symbol, const void *src, size_t count, size_t offset,
    mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphAddMemcpyNodeToSymbol(pGraphNode, graph, pDependencies, numDependencies, sym,
                                           src, count, offset, kind);
}

/** @brief Typed overload: adds a graph node copying @p count bytes out of device symbol @p symbol. */
template <class T>
static __inline__ __host__ mcError_t wcudaGraphAddMemcpyNodeFromSymbol(
    mcGraphNode_t *pGraphNode, mcGraph_t graph, const mcGraphNode_t *pDependencies,
    size_t numDependencies, void *dst, const T &symbol, size_t count, size_t offset,
    mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphAddMemcpyNodeFromSymbol(pGraphNode, graph, pDependencies, numDependencies,
                                             dst, sym, count, offset, kind);
}

/** @brief Typed overload: updates a memcpy node to copy into device symbol @p symbol. */
template <class T>
static __inline__ __host__ mcError_t
wcudaGraphMemcpyNodeSetParamsToSymbol(mcGraphNode_t node, const T &symbol, const void *src,
                                      size_t count, size_t offset, mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphMemcpyNodeSetParamsToSymbol(node, sym, src, count, offset, kind);
}

/** @brief Typed overload: updates a memcpy node to copy out of device symbol @p symbol. */
template <class T>
static __inline__ __host__ mcError_t wcudaGraphMemcpyNodeSetParamsFromSymbol(
    mcGraphNode_t node, void *dst, const T &symbol, size_t count, size_t offset, mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphMemcpyNodeSetParamsFromSymbol(node, dst, sym, count, offset, kind);
}

/** @brief Typed overload: updates an instantiated graph's memcpy-to-symbol node. */
template <class T>
static __inline__ __host__ mcError_t wcudaGraphExecMemcpyNodeSetParamsToSymbol(
    mcGraphExec_t hGraphExec, mcGraphNode_t node, const T &symbol, const void *src, size_t count,
    size_t offset, mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphExecMemcpyNodeSetParamsToSymbol(hGraphExec, node, sym, src, count, offset,
                                                     kind);
}

/** @brief Typed overload: updates an instantiated graph's memcpy-from-symbol node. */
template <class T>
static __inline__ __host__ mcError_t wcudaGraphExecMemcpyNodeSetParamsFromSymbol(
    mcGraphExec_t hGraphExec, mcGraphNode_t node, void *dst, const T &symbol, size_t count,
    size_t offset, mcMemcpyKind kind)
{
    const void *sym = static_cast<const void *>(&symbol);
    return wcudaGraphExecMemcpyNodeSetParamsFromSymbol(hGraphExec, node, dst, sym, count, offset,
                                                       kind);
}

/**
 * @brief Wraps @p objectToWrap in a user object whose destroy callback
 *        deletes it as a T* when the reference count reaches zero.
 */
template <class T>
static __inline__ __host__ mcError_t wcudaUserObjectCreate(mcUserObject_t *object_out,
                                                           T *objectToWrap,
                                                           unsigned int initialRefcount,
                                                           unsigned int flags)
{
    /* Capture-less lambda converts to the plain destroy-callback pointer type. */
    auto destroy = [](void *vpObj) { delete reinterpret_cast<T *>(vpObj); };
    return wcudaUserObjectCreate(object_out, objectToWrap, destroy, initialRefcount, flags);
}

/** @brief Enum-flags convenience overload; widens @p flags and forwards. */
template <class T>
static __inline__ __host__ mcError_t wcudaUserObjectCreate(mcUserObject_t *object_out,
                                                           T *objectToWrap,
                                                           unsigned int initialRefcount,
                                                           enum mcUserObjectFlags flags)
{
    unsigned int rawFlags = static_cast<unsigned int>(flags);
    return wcudaUserObjectCreate(object_out, objectToWrap, initialRefcount, rawFlags);
}

/**
 * @brief Binds linear device memory @p devPtr to texture reference @p tex
 *        using the explicit channel descriptor @p desc.
 *
 * Note: the UINT_MAX default requires <limits.h>, which this header now
 * includes explicitly (it was previously relied on transitively).
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTexture(size_t *offset,
                                                      const struct texture<T, dim, readMode> &tex,
                                                      const void *devPtr,
                                                      const mcChannelFormatDesc &desc,
                                                      size_t size = UINT_MAX)
{
    return wcudaBindTexture(offset, &tex, devPtr, &desc, size);
}

/**
 * @brief Binds linear device memory to @p tex using the channel descriptor
 *        embedded in the texture reference itself.
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTexture(size_t *offset,
                                                      const struct texture<T, dim, readMode> &tex,
                                                      const void *devPtr, size_t size = UINT_MAX)
{
    const mcChannelFormatDesc &desc = tex.channelDesc;
    return wcudaBindTexture(offset, tex, devPtr, desc, size);
}

/**
 * @brief Binds pitched 2-D device memory to texture reference @p tex with an
 *        explicit channel descriptor.
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTexture2D(size_t *offset,
                                                        const struct texture<T, dim, readMode> &tex,
                                                        const void *devPtr,
                                                        const mcChannelFormatDesc &desc,
                                                        size_t width, size_t height, size_t pitch)
{
    mcError_t result = wcudaBindTexture2D(offset, &tex, devPtr, &desc, width, height, pitch);
    return result;
}

/**
 * @brief Binds pitched 2-D device memory to @p tex using the channel
 *        descriptor embedded in the texture reference.
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTexture2D(size_t *offset,
                                                        const struct texture<T, dim, readMode> &tex,
                                                        const void *devPtr, size_t width,
                                                        size_t height, size_t pitch)
{
    const mcChannelFormatDesc *desc = &tex.channelDesc;
    return wcudaBindTexture2D(offset, &tex, devPtr, desc, width, height, pitch);
}

/** @brief Binds array @p array to texture reference @p tex with an explicit channel descriptor. */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t
wcudaBindTextureToArray(const struct texture<T, dim, readMode> &tex, mcArray_const_t array,
                        const mcChannelFormatDesc &desc)
{
    mcError_t result = wcudaBindTextureToArray(&tex, array, &desc);
    return result;
}

/**
 * @brief Binds array @p array to texture reference @p tex, first querying the
 *        array's channel descriptor from the runtime.
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t
wcudaBindTextureToArray(const struct texture<T, dim, readMode> &tex, mcArray_const_t array)
{
    mcChannelFormatDesc desc;
    mcError_t err = wcudaGetChannelDesc(&desc, array);
    if (err != mcSuccess) {
        return err;
    }
    return wcudaBindTextureToArray(tex, array, desc);
}

/** @brief Binds a mipmapped array to @p tex with an explicit channel descriptor. */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTextureToMipmappedArray(
    const struct texture<T, dim, readMode> &tex, mcMipmappedArray_const_t mipmappedArray,
    const mcChannelFormatDesc &desc)
{
    mcError_t result = wcudaBindTextureToMipmappedArray(&tex, mipmappedArray, &desc);
    return result;
}

/**
 * @brief Binds a mipmapped array to @p tex, deriving the channel descriptor
 *        from mip level 0 of the array.
 */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaBindTextureToMipmappedArray(
    const struct texture<T, dim, readMode> &tex, mcMipmappedArray_const_t mipmappedArray)
{
    mcArray_t levelArray;
    mcError_t err = wcudaGetMipmappedArrayLevel(&levelArray, mipmappedArray, 0);
    if (err != mcSuccess) {
        return err;
    }

    mcChannelFormatDesc desc;
    err = wcudaGetChannelDesc(&desc, levelArray);
    if (err != mcSuccess) {
        return err;
    }

    return wcudaBindTextureToMipmappedArray(tex, mipmappedArray, desc);
}

/** @brief Unbinds texture reference @p tex from whatever memory it is bound to. */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t wcudaUnbindTexture(const struct texture<T, dim, readMode> &tex)
{
    mcError_t result = wcudaUnbindTexture(&tex);
    return result;
}

/** @brief Queries the byte offset applied when @p tex was bound to unaligned memory. */
template <class T, int dim, enum mcTextureReadMode readMode>
static __inline__ __host__ mcError_t
wcudaGetTextureAlignmentOffset(size_t *offset, const struct texture<T, dim, readMode> &tex)
{
    mcError_t result = wcudaGetTextureAlignmentOffset(offset, &tex);
    return result;
}

/** @brief Binds array @p array to surface reference @p surf with an explicit channel descriptor. */
template <class T, int dim>
static __inline__ __host__ mcError_t wcudaBindSurfaceToArray(const struct surface<T, dim> &surf,
                                                             const mcArray *array,
                                                             const mcChannelFormatDesc &desc)
{
    mcError_t result = wcudaBindSurfaceToArray(&surf, array, &desc);
    return result;
}

/**
 * @brief Binds array @p array to surface reference @p surf using the channel
 *        descriptor embedded in the surface reference.
 */
template <class T, int dim>
static __inline__ __host__ mcError_t wcudaBindSurfaceToArray(const struct surface<T, dim> &surf,
                                                             const mcArray *array)
{
    const mcChannelFormatDesc *desc = &surf.channelDesc;
    return wcudaBindSurfaceToArray(&surf, array, desc);
}

#endif /*  __INCLUDE_CUDA_RUNTIME_TEMPLATE_WRAPPER_H__  */